# rm(list = ls());
gc()

# source("00_pckgs.R")

# Use every available core for parallel work. parallel::detectCores() returns
# NA when the core count cannot be determined, so fall back to a single core
# instead of storing NA in the mc.cores option.
n_cores <- parallel::detectCores()
options(mc.cores = if (is.na(n_cores)) 1L else n_cores)

# set.seed(1917)

# full_sessions <- 1:5
# run <- "first"
# toggle_nonsep <- 0

# NOTE(review): presumably provides rcv_ep and mp_ep_final, which the fitting
# code reads — confirm.
load("data/ep/full_ep.Rda")

# Everything below branches on `toggle_nonsep` (0 -> "sirt" result files,
# anything else -> "nsirt"); it has to be defined by the caller before this
# point. Fail early with a readable message rather than "object not found".
if (!exists("toggle_nonsep")) {
  stop("`toggle_nonsep` must be defined before running this script",
       call. = FALSE)
}

# IRT Models ####
# Decide which sessions to (re)fit and set `sessions` and `run`.
#
# Progress is tracked in a status file (results/ep/sirt.txt, or nsirt.txt when
# toggle_nonsep != 0). Each line has the form
#   "session <s> <date> <time> <max rhat>"
# and is read back as columns V1..V5.
#  * Fewer than `n_sessions` lines: the first pass is still in progress, so
#    continue with the sessions after the last one logged (run = "first").
#  * Otherwise: refit every session whose best (minimum) logged max-rhat is
#    still >= 1.1 (run = "not-first").
status_file <- if (toggle_nonsep == 0) "results/ep/sirt.txt" else "results/ep/nsirt.txt"
n_sessions <- 5

if (!file.exists(status_file)) {
  invisible(file.create(status_file))
}

# read.table() errors ("no lines available in input") on an empty file. That
# is exactly the state left behind when the first fit crashes right after the
# status file was created, so count non-blank lines first and only parse when
# there is something to parse.
status_lines <- readLines(status_file)
status_lines <- status_lines[nzchar(trimws(status_lines))]

if (length(status_lines) < n_sessions) {
  # Sessions are logged in order during the first pass.
  sessions <- (length(status_lines) + 1):n_sessions
  run <- "first"
} else {
  sessions <- read.table(text = status_lines) %>%
    group_by(V2) %>%
    summarize(V5 = min(V5)) %>%
    filter(V5 >= 1.1) %>%
    pull(V2)
  run <- "not-first"
}


# if (run == "first"){
#   if (crash_check == "pre"){
#     if (toggle_nonsep == 0){
#       file.create("results/ep/sirt.txt")
#     } else {
#       file.create("results/ep/nsirt.txt")
#     }
#     sessions <- full_sessions
#   } else {
#     if (toggle_nonsep == 0){
#       done <- read.table("results/ep/sirt.txt") %>% 
#         pull(V2)
#       sessions <- c(1:5)[!(c(1:5) %in% done)]
#     } else {
#       done <- read.table("results/ep/nsirt.txt") %>% 
#         pull(V2)
#       sessions <- c(1:5)[!(c(1:5) %in% done)]
#     }
#   }
# } else {
#   if (toggle_nonsep == 0){
#     cirt_results <- read.table("results/ep/sirt.txt")
#   } else {
#     cirt_results <- read.table("results/ep/nsirt.txt")
#   }
#   sessions <- cirt_results %>% 
#     group_by(V2) %>% 
#     summarize(V5 = min(V5)) %>% 
#     filter(V5 >= 1.1) %>% 
#     pull(V2)
# }

# Fit the dynamic IRT model (mod_dyn.stan) once per session in `sessions`.
#
# Session 1 builds its Stan inputs from scratch. Every later session s loads
# the results saved for session s - 1 and passes the previous posterior-mean
# positions of returning MEPs to Stan (theta_old / theta_old_ind).
#
# After each fit the script saves fit_sum, loo_sum and mem_desc to
# results/ep/<s>/<sirt|nsirt>.Rda and appends one status line
#   "session <s> <YYYY-mm-dd> <HH:MM:SS> <max rhat>"
# to results/ep/<sirt|nsirt>.txt. That file is read back (columns V1..V5) to
# decide which sessions to fit or refit.

# Result files are named after the model variant: "sirt" when toggle_nonsep is
# 0, "nsirt" otherwise.
res_name <- if (toggle_nonsep == 0) "sirt" else "nsirt"

# The Stan program is the same for every session, so build it once instead of
# on every iteration (and not at all when nothing is left to fit).
if (length(sessions) > 0) {
  mod <- cmdstan_model("mod_dyn.stan")
}

for (s in sessions) {
  if (run == "first") {
    dir.create(file.path("results/ep", s), showWarnings = FALSE, recursive = TRUE)
  }

  if (s == 1) {
    rc <- rcv_ep[[1]]
    mem_desc <- mp_ep_final %>%
      filter(session == 1)

    # Map each MEPID to its position among the unique MEPIDs of rc.
    mep_ids <- as.character(unique(rc$MEPID))
    id <- setNames(seq_along(mep_ids), mep_ids)
    mem_desc$mepid <- id[as.character(mem_desc$id)]

    # Keep only MEPs present in rc, ordered like rc, then:
    # - add +1 / -1 / 0 party-group codes lr and eu,
    # - standardise lr, eu, dim1 and dim2 (mean 0, sd 1),
    # - flip the sign of dim2.
    mem_desc <- mem_desc %>%
      filter(!(is.na(mepid))) %>%
      arrange(mepid) %>%
      mutate(lr = case_when(party_group %in% c("C", "E", "G") ~ 1,
                            party_group %in% c("M", "V") ~ -1,
                            TRUE ~ 0),
             eu = case_when(party_group %in% c("G", "M", "A", "V", "X") ~ 1,
                            party_group %in% c("S", "C") ~ -1,
                            TRUE ~ 0)) %>%
      mutate(across(c("lr", "eu", "dim1", "dim2"), ~(.x - mean(.x, na.rm = TRUE)) / sd(.x, na.rm = TRUE))) %>%
      mutate(dim2 = dim2 * -1)

    # Item starting values, one row per vote column (all but the first 5
    # columns of rc):
    # - column 1: intercept - .5,
    # - columns 2-3: slopes on dim1 and dim2.
    # Computed before missing votes are recoded to -1 below.
    pr_init <- matrix(NA, nrow = ncol(rc[,-c(1:5)]), ncol = 3)
    pr_init[,1] <- t(apply(rc[,-c(1:5)], 2, function(x) coef(summary(lm(x~1)))[1,1]))-.5
    pr_init[,2] <- t(apply(rc[,-c(1:5)], 2, function(x) coef(summary(lm(x~mem_desc$dim1)))[2,1]))
    pr_init[,3] <- t(apply(rc[,-c(1:5)], 2, function(x) coef(summary(lm(x~mem_desc$dim2)))[2,1]))

    # Per-chain initial values: scaled pr_init plus N(0, .25) jitter.
    initF <- function() {
      list(
        gamma = rep(2,2),
        L = chol(matrix(c(1,0,0,1), ncol = 2)),
        theta_free = matrix(0, ncol = 2, nrow = nrow(rc)),
        bk_free = apply(pr_init[,2:3], 2, function(x) x / sd(x)) + rnorm(ncol(rc[,-c(1:5)])*2, 0, .25),
        ak = pr_init[,1] / sd(pr_init[,1]) + rnorm(ncol(rc[,-c(1:5)]), 0, .25),
        theta_mu = rnorm(2, rep(0, 2), .25)
      )
    }

    # theta_in: slice 1 = (lr, 0), slice 2 = (eu, 0); no previous-session
    # shift exists for the first session.
    theta_in <- abind(as.matrix(cbind(mem_desc[,c("lr")], 0), ncol = 2),
                      as.matrix(cbind(mem_desc[,c("eu")], 0), ncol = 2),
                      along = 3)

    # Missing votes are coded -1 and excluded from the observation index below.
    rc[is.na(rc)] <- -1

    # Column (proposal) and row (legislator) index of every vote cell.
    rc_c <- matrix(rep(1:ncol(rc[,-c(1:5)]), each = nrow(rc[,-c(1:5)])), nrow(rc[,-c(1:5)]), ncol(rc[,-c(1:5)]))
    rc_r <- matrix(rep(1:nrow(rc[,-c(1:5)]), each = ncol(rc[,-c(1:5)])), nrow(rc[,-c(1:5)]), ncol(rc[,-c(1:5)]), byrow = TRUE)

    stan_data <- list(
      I = ncol(rc[,-c(1:5)]), # proposal
      J = nrow(rc[,-c(1:5)]), # legislator
      J_old = 2,
      N = sum(rc[,-c(1:5)] != -1),
      ii = rc_c[rc[,-c(1:5)] != -1],
      jj = rc_r[rc[,-c(1:5)] != -1],
      Y = as.matrix(rc[,-c(1:5)]), # yea
      incl_nonsep = toggle_nonsep,
      K = 2,
      T = 0,
      theta_in = theta_in,
      theta_old = matrix(c(1,2), ncol = 2, nrow = 2),
      theta_old_ind = c(1,2),
      D = 2)
  } else {
    # Previous session's results (fit_sum, loo_sum, mem_desc). Its posterior
    # means dim1_n / dim2_n become this session's dim1_o / dim2_o.
    load(file.path("results/ep", s - 1, paste0(res_name, ".Rda")))
    mem_desc$dim2_o <- mem_desc$dim1_o <- NULL
    mem_old <- mem_desc %>% rename(dim1_o = dim1_n, dim2_o = dim2_n)
    rm(fit_sum, loo_sum, mem_desc)

    rc <- rcv_ep[[s]]
    mem_desc <- mp_ep_final %>%
      filter(session == s)

    # Map each MEPID to its position among the unique MEPIDs of rc.
    mep_ids <- as.character(unique(rc$MEPID))
    id <- setNames(seq_along(mep_ids), mep_ids)
    mem_desc$mepid <- id[as.character(mem_desc$id)]

    # Same party-group coding as session 1, plus:
    # - merge in the previous positions by id (new = 1 when an MEP has none),
    # - standardise lr, eu, dim1_o and dim2_o,
    # - lr_shift / eu_shift: deviation of each old position from the mean of
    #   its lr / eu group.
    mem_desc <- mem_desc %>%
      filter(!(is.na(mepid))) %>%
      arrange(mepid) %>%
      mutate(lr = case_when(party_group %in% c("C", "E", "G") ~ 1,
                            party_group %in% c("M", "V") ~ -1,
                            TRUE ~ 0),
             eu = case_when(party_group %in% c("G", "M", "A", "V", "X") ~ 1,
                            party_group %in% c("S", "C") ~ -1,
                            TRUE ~ 0)) %>%
      merge(mem_old %>% select(id, dim1_o, dim2_o), all.x = TRUE, sort = FALSE) %>%
      mutate(new = ifelse(is.na(dim1_o), 1, 0)) %>%
      mutate(across(c("lr", "eu", "dim1_o", "dim2_o"), ~(.x - mean(.x, na.rm = TRUE)) / sd(.x, na.rm = TRUE))) %>%
      group_by(lr) %>%
      mutate(lr_shift = dim1_o - mean(dim1_o, na.rm = TRUE)) %>%
      ungroup() %>%
      group_by(eu) %>%
      mutate(eu_shift = dim2_o - mean(dim2_o, na.rm = TRUE)) %>%
      ungroup() %>%
      arrange(mepid)

    # Intercept and slope of old positions on lr (row 1) and on eu (row 2);
    # used below to draw starting positions for new MEPs.
    gamma_init <- rbind(coef(summary(lm(mem_desc$dim1_o ~ mem_desc$lr)))[,1],
                        coef(summary(lm(mem_desc$dim2_o ~ mem_desc$eu)))[,1])

    # Item starting values (intercept - .5, slopes on dim1_o / dim2_o),
    # computed before missing votes are recoded to -1 below.
    pr_init <- matrix(NA, nrow = ncol(rc[,-c(1:5)]), ncol = 3)
    pr_init[,1] <- t(apply(rc[,-c(1:5)], 2, function(x) coef(summary(lm(x~1)))[1,1]))-.5
    pr_init[,2] <- t(apply(rc[,-c(1:5)], 2, function(x) coef(summary(lm(x~mem_desc$dim1_o)))[2,1]))
    pr_init[,3] <- t(apply(rc[,-c(1:5)], 2, function(x) coef(summary(lm(x~mem_desc$dim2_o)))[2,1]))

    mem_desc$dim1_o[mem_desc$new == 1] <- rnorm(sum(mem_desc$new), gamma_init[1,1] + gamma_init[1,2]*mem_desc$lr[mem_desc$new == 1], .25)
    mem_desc$dim2_o[mem_desc$new == 1] <- rnorm(sum(mem_desc$new), gamma_init[2,1] + gamma_init[2,2]*mem_desc$eu[mem_desc$new == 1], .25)

    # Per-chain initial values: scaled pr_init plus N(0, .25) jitter.
    initF <- function() {
      list(
        gamma = rep(2,2),
        L = chol(matrix(c(1,0,0,1), ncol = 2)),
        theta_free = matrix(0, ncol = 2, nrow = nrow(rc[,-c(1:5)])),
        bk_free = apply(pr_init[,2:3], 2, function(x) x / sd(x)) + rnorm(ncol(rc[,-c(1:5)])*2, 0, .25),
        ak = pr_init[,1] / sd(pr_init[,1]) + rnorm(ncol(rc[,-c(1:5)]), 0, .25)
      )
    }

    # Scale the shifts to unit sd; new MEPs get a zero shift.
    mem_desc <- mem_desc %>%
      mutate(across(c("lr_shift", "eu_shift"), ~ (.x / sd(.x, na.rm = TRUE)))) %>%
      group_by(lr) %>%
      mutate(lr_shift = ifelse(new == 1, 0, lr_shift)) %>%
      ungroup() %>%
      group_by(eu) %>%
      mutate(eu_shift = ifelse(new == 1, 0, eu_shift)) %>%
      ungroup() %>%
      arrange(mepid)

    # theta_in: slice 1 = (lr, lr_shift), slice 2 = (eu, eu_shift).
    theta_in <- abind(as.matrix(mem_desc[,c("lr", "lr_shift")], ncol = 2),
                      as.matrix(mem_desc[,c("eu", "eu_shift")], ncol = 2),
                      along = 3)

    # Missing votes are coded -1 and excluded from the observation index below.
    rc[is.na(rc)] <- -1

    # Column (proposal) and row (legislator) index of every vote cell.
    rc_c <- matrix(rep(1:ncol(rc[,-c(1:5)]), each = nrow(rc[,-c(1:5)])), nrow(rc[,-c(1:5)]), ncol(rc[,-c(1:5)]))
    rc_r <- matrix(rep(1:nrow(rc[,-c(1:5)]), each = ncol(rc[,-c(1:5)])), nrow(rc[,-c(1:5)]), ncol(rc[,-c(1:5)]), byrow = TRUE)

    stan_data <- list(
      I = ncol(rc[,-c(1:5)]), # proposal
      J = nrow(rc[,-c(1:5)]), # legislator
      N = sum(rc[,-c(1:5)] != -1),
      ii = rc_c[rc[,-c(1:5)] != -1],
      jj = rc_r[rc[,-c(1:5)] != -1],
      J_old = length(mem_desc$mepid[mem_desc$new != 1]),
      Y = as.matrix(rc[,-c(1:5)]), # yea
      incl_nonsep = toggle_nonsep,
      K = 2,
      T = 1,
      theta_in = theta_in,
      theta_old = mem_desc[mem_desc$new != 1, c("dim1_o", "dim2_o")],
      theta_old_ind = mem_desc$mepid[mem_desc$new != 1],
      D = 2)
  }

  fit <- mod$sample(data = stan_data, init = initF, iter_warmup = 500, iter_sampling = 500, chains = 8, seed = 2021)

  end <- Sys.time()

  # extract results
  fit_sum <- (fit$summary(c("gamma_post","ak", "sal", "weights", "bk", "theta", "r_old", "r_dim", "r_sq", "m_prop_compl", "sd_prop_compl", "p_share", "DL")))
  fit_sum$parameter <- as.factor(gsub("\\[.*]", "", fit_sum$variable))

  # Worst (largest) available rhat over ak, weights, bk and theta. It is
  # logged below and compared against 1.1 when deciding what to refit.
  # (is.na() is also TRUE for NaN.)
  r <- max(fit_sum$rhat[!is.na(fit_sum$rhat) & fit_sum$parameter %in% c("ak", "weights", "bk", "theta")])

  gc();Sys.sleep(15);gc()
  # For sessions after 3 only the raw log_lik draws are kept here; they are
  # turned into loo objects further down in this script.
  if (s > 3) {
    loo_sum <- fit$draws("log_lik")
  } else {
    loo_sum <- fit$loo()
  }
  # Posterior means of theta, one column per dimension.
  # NOTE(review): assumes theta is J x 2 and summarised in column-major order.
  tmp <- matrix(fit_sum$mean[fit_sum$parameter == "theta"], ncol = 2)

  mem_desc <- mem_desc %>% mutate(dim1_n = tmp[,1], dim2_n = tmp[,2])

  save(fit_sum, loo_sum, mem_desc,
       file = file.path("results/ep", s, paste0(res_name, ".Rda")))
  rm(fit);rm(loo_sum)
  gc();Sys.sleep(60);gc()

  # Always format the timestamp explicitly. Relying on the default POSIXct
  # formatting drops the time field for an exact-midnight timestamp, which
  # would leave a 4-field line that the 5-column read.table() of this file
  # cannot parse.
  status_line <- paste("session", s, format(end, "%Y-%m-%d %H:%M:%S"), round(r, 3))
  write(status_line, file = file.path("results/ep", paste0(res_name, ".txt")), append = TRUE)
}

gc();Sys.sleep(60);gc()
# rstudioapi::restartSession(command = "")
# gc();
Sys.sleep(60);gc()

# Sessions 4 and 5 store the raw log_lik draws instead of a loo object. Build
# the loo objects now that the fitting memory has been released, and re-save
# each session's result file.
res_name <- if (toggle_nonsep == 0) "sirt" else "nsirt"
for (s in 4:5) {
  res_file <- file.path("results/ep", s, paste0(res_name, ".Rda"))
  # A session that has not been fitted yet has no result file to convert.
  if (!file.exists(res_file)) {
    next
  }
  load(res_file)
  if (!loo::is.loo(loo_sum)) {
    # Sessions 1-3 use cmdstanr's $loo(), which supplies relative
    # efficiencies (r_eff) by default. Do the same here so PSIS diagnostics
    # are comparable across sessions; relative_eff() expects likelihood, not
    # log-likelihood, values. Note this briefly holds a second copy of the
    # iterations x chains x observations array.
    log_lik <- unclass(loo_sum)
    loo_sum <- loo::loo(log_lik, r_eff = loo::relative_eff(exp(log_lik)))
    rm(log_lik)
  }
  gc()
  save(fit_sum, loo_sum, mem_desc, file = res_file)
}